import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

# Single seeded generator shared by the whole script. make_graph,
# make_weights, propagate_inference and one_trial all draw from this
# module-level `rng`. A full run is therefore reproducible, provided the
# calls happen in the same order.
rng = np.random.default_rng(20260525)

# ============================================================
# CNVS TEST 9a / 9b / 9c
# Dependent Graph Propagation Monte Carlo
# ============================================================
# 9a = Sparse graph
# 9b = Small-world graph
# 9c = Dense graph
#
# Victory:
# the attacker wins if either of the following holds:
# 1) every fragment (all are treated as critical) is known, by direct
#    capture plus graph inference; OR
# 2) the remaining weighted residual entropy H_res (in bits) is guessed,
#    which succeeds with probability 2^(-H_res).
# ============================================================


def make_graph(M, graph_type, avg_degree=4, rewire_p=0.08):
    """Build an undirected dependency graph over ``M`` fragments.

    Parameters
    ----------
    M : int
        Number of nodes (fragments), labelled ``0 .. M-1``.
    graph_type : str
        ``"sparse"``: a ring plus random extra chords. The chord budget is
        ``int(M * avg_degree / 2) - M``. Self and duplicate chords are
        skipped, so the realised mean degree can be slightly below
        ``avg_degree``.

        ``"small_world"``: Watts-Strogatz. Start from a ring lattice joining
        each node to its ``k // 2`` successors, where ``k`` is
        ``avg_degree`` rounded up to an even number and is at least 2. Then
        rewire every lattice edge exactly once with probability
        ``rewire_p``.

        ``"dense"``: each pair is linked independently with probability
        ``min(0.5, avg_degree / max(1, M - 1))``. Isolated nodes are
        afterwards linked to their ring successor.
    avg_degree : int
        Target mean degree.
    rewire_p : float
        Per-edge rewiring probability. Only used by ``"small_world"``.

    Returns
    -------
    list of numpy.ndarray
        ``adj[i]`` is the sorted int array of node ``i``'s neighbours. The
        graph never contains self-loops.

    Raises
    ------
    ValueError
        If ``graph_type`` is not one of the names above.

    All randomness comes from the module-level ``rng``.
    """
    adj = [set() for _ in range(M)]

    def connect(a, b):
        # Self-loops only arise from ring wrap-around when M is tiny. They
        # are not meaningful dependencies, so they are never stored.
        if a != b:
            adj[a].add(b)
            adj[b].add(a)

    if graph_type == "sparse":
        for i in range(M):
            connect(i, (i + 1) % M)

        extra_edges = max(0, int(M * avg_degree / 2) - M)
        for _ in range(extra_edges):
            a, b = rng.integers(0, M, size=2)
            connect(int(a), int(b))

    elif graph_type == "small_world":
        k = max(2, avg_degree)
        if k % 2 == 1:
            k += 1

        half_k = k // 2

        for i in range(M):
            for d in range(1, half_k + 1):
                connect(i, (i + d) % M)

        # Rewire each lattice edge (i, i + d) exactly once. Endpoint i is
        # kept, and the other endpoint moves to a uniformly random node
        # that is neither i nor already adjacent to i. This avoids
        # self-loops and duplicates, so no edge is silently lost.
        for d in range(1, half_k + 1):
            for i in range(M):
                if rng.random() >= rewire_p:
                    continue

                j = (i + d) % M
                if j not in adj[i] or len(adj[i]) >= M - 1:
                    # Either the edge is absent, which only happens when M
                    # is so small that lattice edges wrap onto each other,
                    # or i is already adjacent to every other node.
                    continue

                new = int(rng.integers(0, M))
                while new == i or new in adj[i]:
                    new = int(rng.integers(0, M))

                adj[i].discard(j)
                adj[j].discard(i)
                connect(i, new)

    elif graph_type == "dense":
        p = min(0.5, avg_degree / max(1, M - 1))

        for i in range(M):
            for j in range(i + 1, M):
                if rng.random() < p:
                    connect(i, j)

        # Give every isolated node at least one dependency.
        for i in range(M):
            if not adj[i]:
                connect(i, (i + 1) % M)

    else:
        raise ValueError("graph_type must be: sparse, small_world, dense")

    # Sorted neighbour arrays make the adjacency independent of set order.
    return [np.array(sorted(s), dtype=int) for s in adj]


def make_weights(M, heterogeneity=0.8):
    """Return ``M`` positive fragment weights normalised to sum to 1.

    ``heterogeneity`` is the ``sigma`` of a log-normal draw with ``mean=0``,
    taken from the module-level ``rng``. Any value ``<= 0`` skips the random
    draw and yields perfectly uniform weights ``1 / M``.
    """
    raw = np.ones(M) if heterogeneity <= 0 else rng.lognormal(
        mean=0.0, sigma=heterogeneity, size=M
    )
    total = raw.sum()
    return raw / total


def propagate_inference(
    known_direct,
    adj,
    weights,
    rho_c,
    max_rounds=4,
    semantic_pressure=1.0
):
    """Spread adversarial knowledge through the dependency graph.

    Rounds are synchronous. In each round, every still-unknown node ``i``
    with at least one known neighbour is inferred with probability::

        p = 1 - (1 - rho_c) ** s
        s = semantic_pressure * (n_known + M * w_known)

    Here ``n_known`` is the number of known neighbours of ``i`` and
    ``w_known`` is their total weight. Nodes inferred in a round only count
    as evidence from the next round on. Propagation stops after
    ``max_rounds`` rounds, or earlier when a round infers nothing.

    The work per round is vectorised over a flat edge list. Exactly one
    uniform draw is taken from the module-level ``rng`` per candidate node,
    in ascending node order.

    Parameters
    ----------
    known_direct : numpy.ndarray of bool, shape (M,)
        Fragments captured directly. This array is not modified.
    adj : sequence of int arrays
        ``adj[i]`` lists the neighbours of node ``i``, one entry per node.
    weights : numpy.ndarray, shape (M,)
        Per-fragment weights.
    rho_c : float
        Per-unit inference strength.
    max_rounds : int
        Maximum number of propagation rounds.
    semantic_pressure : float
        Multiplier on the local signal.

    Returns
    -------
    (inferred_total, known) : tuple of bool arrays, shape (M,)
        ``inferred_total`` marks nodes gained by inference only. ``known``
        is direct capture plus inference.
    """
    M = len(known_direct)
    known = known_direct.copy()
    inferred_total = np.zeros(M, dtype=bool)

    if M == 0:
        return inferred_total, known

    # Flat edge list: edge e runs from node owners[e] to node targets[e].
    degrees = np.array([len(nbrs) for nbrs in adj], dtype=np.intp)
    owners = np.repeat(np.arange(M), degrees)
    targets = np.concatenate(adj).astype(np.intp, copy=False)
    target_weights = weights[targets]

    for _ in range(max_rounds):
        known_at_target = known[targets]

        # Per node: count and total weight of currently known neighbours.
        n_known = np.bincount(
            owners, weights=known_at_target.astype(np.float64), minlength=M
        )
        w_known = np.bincount(
            owners,
            weights=np.where(known_at_target, target_weights, 0.0),
            minlength=M,
        )

        candidates = np.flatnonzero(~known & (n_known > 0))
        if candidates.size == 0:
            break

        local_signal = semantic_pressure * (
            n_known[candidates] + M * w_known[candidates]
        )
        p_infer = np.clip(1.0 - (1.0 - rho_c) ** local_signal, 0.0, 1.0)

        hits = candidates[rng.random(candidates.size) < p_infer]
        if hits.size == 0:
            break

        inferred_total[hits] = True
        known[hits] = True

    return inferred_total, known


def one_trial(
    pool,
    M,
    I_G,
    q,
    adj,
    weights,
    rho_c,
    max_rounds,
    semantic_pressure
):
    """Run one Monte Carlo capture, inference and guessing trial.

    A fraction ``q`` of the ``pool`` entries (fragment ids, one per stored
    copy) is sampled without replacement. Their fragments count as directly
    known. Graph inference then runs via ``propagate_inference``. The
    attacker wins if every fragment is known, or otherwise if a guess of the
    residual ``H_res = I_G * weight(missing)`` bits succeeds with
    probability ``2 ** -H_res``.

    ``complete_by_knowledge`` and ``complete_by_entropy`` are mutually
    exclusive. The entropy route is credited only when knowledge alone did
    not already complete the reconstruction, so each trial's win is
    attributed to exactly one route.

    Returns
    -------
    dict
        Per-trial metrics: K_direct, K_inferential, K_adv, H_res, P_guess,
        missing/inferred/known counts, win, and the two completion flags.
    """
    N = len(pool)
    k = int(round(q * N))

    captured = rng.choice(pool, size=k, replace=False)

    known_direct = np.zeros(M, dtype=bool)
    # Repeated fragment ids in `captured` just set the same flag again.
    known_direct[captured] = True

    inferred, known_total = propagate_inference(
        known_direct=known_direct,
        adj=adj,
        weights=weights,
        rho_c=rho_c,
        max_rounds=max_rounds,
        semantic_pressure=semantic_pressure
    )

    missing = ~known_total

    K_direct = weights[known_direct].sum()
    K_inferential = weights[inferred].sum()
    K_adv = weights[known_total].sum()

    H_res = weights[missing].sum() * I_G

    P_guess = 0.0 if H_res > 1024 else 2.0 ** (-H_res)

    # Take exactly one uniform draw per trial, whatever the outcome, so the
    # random stream seen by later trials is unchanged.
    entropy_draw = rng.random()

    complete_by_knowledge = bool(known_total.all())
    # With full knowledge H_res == 0 and P_guess == 1. Without the
    # `not complete_by_knowledge` guard, every knowledge win would also be
    # counted as an entropy win.
    complete_by_entropy = (not complete_by_knowledge) and bool(
        entropy_draw < P_guess
    )

    win = complete_by_knowledge or complete_by_entropy

    return {
        "K_direct": K_direct,
        "K_inferential": K_inferential,
        "K_adv": K_adv,
        "H_res": H_res,
        "P_guess": P_guess,
        "missing_count": int(missing.sum()),
        "inferred_count": int(inferred.sum()),
        "known_count": int(known_total.sum()),
        "win": win,
        "complete_by_knowledge": complete_by_knowledge,
        "complete_by_entropy": complete_by_entropy,
    }


def simulate_graph_test(
    label,
    graph_type,
    I_G=1000,
    I0=5,
    lambda_factor=2,
    rho_c=0.002,
    avg_degree=4,
    heterogeneity=0.8,
    max_rounds=4,
    semantic_pressure=1.0,
    q_values=None,
    trials=5000,
    rewire_p=0.08,
):
    """Sweep the compromise fraction ``q`` for one graph topology.

    One graph (``make_graph``) and one weight vector (``make_weights``) are
    drawn per call and reused for every ``q`` and every trial. Only the
    per-trial capture sample and the inference and guess draws vary.

    Parameters
    ----------
    label : str
        Name stored in the ``"test"`` column.
    graph_type : str
        Passed to ``make_graph``.
    I_G, I0 : number
        Total secret size and per-fragment size in bits. The number of
        fragments is ``M = ceil(I_G / I0)``.
    lambda_factor : int
        Number of stored copies of each fragment.
    rho_c, max_rounds, semantic_pressure
        Passed to ``propagate_inference`` through ``one_trial``.
    avg_degree : int
        Passed to ``make_graph``.
    rewire_p : float
        Passed to ``make_graph``. Only affects ``"small_world"``.
    heterogeneity : float
        Passed to ``make_weights``.
    q_values : iterable of float, optional
        Compromise fractions. The default is 100 points in [0.01, 0.99].
    trials : int
        Monte Carlo trials per ``q``. Must be >= 1.

    Returns
    -------
    pandas.DataFrame
        One row per ``q`` with the parameters and aggregated metrics.

    Raises
    ------
    ValueError
        If ``trials < 1``.
    """
    if q_values is None:
        q_values = np.linspace(0.01, 0.99, 100)
    if trials < 1:
        raise ValueError("trials must be at least 1")

    M = math.ceil(I_G / I0)
    # One pool entry per stored copy; each fragment id appears
    # lambda_factor times.
    pool = np.repeat(np.arange(M), lambda_factor)
    N = len(pool)

    adj = make_graph(
        M=M,
        graph_type=graph_type,
        avg_degree=avg_degree,
        rewire_p=rewire_p
    )

    weights = make_weights(
        M=M,
        heterogeneity=heterogeneity
    )

    rows = []

    for q in q_values:
        trial_rows = [
            one_trial(
                pool=pool,
                M=M,
                I_G=I_G,
                q=float(q),
                adj=adj,
                weights=weights,
                rho_c=rho_c,
                max_rounds=max_rounds,
                semantic_pressure=semantic_pressure
            )
            for _ in range(trials)
        ]

        df = pd.DataFrame(trial_rows)

        rows.append({
            "test": label,
            "graph_type": graph_type,
            "q": float(q),
            "I_G_bits": I_G,
            "I0_bits": I0,
            "lambda_factor": lambda_factor,
            "rho_c": rho_c,
            "avg_degree": avg_degree,
            "rewire_p": rewire_p,
            "heterogeneity": heterogeneity,
            "max_rounds": max_rounds,
            "semantic_pressure": semantic_pressure,
            "M_fragments": M,
            "N_total": N,
            "P_win": float(df["win"].mean()),
            "P_complete_by_knowledge": float(df["complete_by_knowledge"].mean()),
            "P_complete_by_entropy": float(df["complete_by_entropy"].mean()),
            "mean_K_direct": float(df["K_direct"].mean()),
            "mean_K_inferential": float(df["K_inferential"].mean()),
            "mean_K_adv": float(df["K_adv"].mean()),
            "p95_K_adv": float(df["K_adv"].quantile(0.95)),
            "mean_H_res": float(df["H_res"].mean()),
            "p05_H_res": float(df["H_res"].quantile(0.05)),
            "mean_missing_count": float(df["missing_count"].mean()),
            "mean_inferred_count": float(df["inferred_count"].mean()),
            "mean_known_count": float(df["known_count"].mean()),
            "mean_P_guess": float(df["P_guess"].mean()),
        })

    return pd.DataFrame(rows)


# ============================================================
# HIGH-QUALITY PARAMETERS FOR COLAB
# ============================================================
# TESTS holds the per-topology knobs. Everything common to all three runs
# lives in SHARED_PARAMS.

Q_VALUES = np.linspace(0.01, 0.99, 100)
TRIALS = 5000

TESTS = [
    dict(
        label="Test 9a — Sparse dependent graph",
        graph_type="sparse",
        avg_degree=3,
        rho_c=0.0015,
        semantic_pressure=0.8,
        max_rounds=3,
    ),
    dict(
        label="Test 9b — Small-world dependent graph",
        graph_type="small_world",
        avg_degree=4,
        rho_c=0.0020,
        semantic_pressure=1.0,
        max_rounds=4,
    ),
    dict(
        label="Test 9c — Dense dependent graph",
        graph_type="dense",
        avg_degree=10,
        rho_c=0.0025,
        semantic_pressure=1.2,
        max_rounds=5,
    ),
]

SHARED_PARAMS = dict(
    I_G=1000,
    I0=5,
    lambda_factor=2,
    heterogeneity=0.8,
    q_values=Q_VALUES,
    trials=TRIALS,
)

# Only these TESTS keys are forwarded to simulate_graph_test.
PER_TEST_KEYS = (
    "label",
    "graph_type",
    "rho_c",
    "avg_degree",
    "max_rounds",
    "semantic_pressure",
)

all_results = []

for test in TESTS:
    print("Running:", test["label"])
    per_test = {key: test[key] for key in PER_TEST_KEYS}
    df_test = simulate_graph_test(**per_test, **SHARED_PARAMS)
    all_results.append(df_test)

results = pd.concat(all_results, ignore_index=True)

results.to_csv("cnvs_test9_dependent_graphs_results.csv", index=False)
results.to_excel("cnvs_test9_dependent_graphs_results.xlsx", index=False)

# ============================================================
# PLOTS 1-3 — P_win, K_adv AND H_res VERSUS q, ONE CURVE PER TOPOLOGY
# ============================================================


def _plot_metric_vs_q(frame, column, ylabel, title, filename, legend_loc):
    """Plot ``frame[column]`` against ``q``, one line per test.

    The figure is saved as a PNG at 200 dpi and then shown. Every plot
    shares the same size, x-axis label and grid, plus a dashed reference
    line at q = 1/3 that is for visual comparison only.
    """
    plt.figure(figsize=(12, 7))

    for label, group in frame.groupby("test"):
        plt.plot(
            group["q"],
            group[column],
            linewidth=2.5,
            label=label
        )

    plt.axvline(
        1/3,
        color="black",
        linestyle="--",
        alpha=0.7,
        label="BFT reference line (visual only)"
    )

    plt.xlabel("Fraction of network physically compromised by attacker (q)")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, linestyle=":", alpha=0.7)
    plt.legend(fontsize=9, loc=legend_loc)
    plt.tight_layout()
    plt.savefig(filename, dpi=200)
    plt.show()


# Each entry: (column, y-axis label, title, output file, legend position).
PLOT_SPECS = [
    (
        "P_win",
        "Probability of complete unauthorized reconstruction",
        "CNVS Test 9a/9b/9c — Dependent Graph Propagation",
        "test9_probability_complete_reconstruction.png",
        "upper left",
    ),
    (
        "mean_K_adv",
        "Mean weighted adversarial knowledge K_adv",
        "CNVS Test 9a/9b/9c — K_adv by Graph Topology",
        "test9_weighted_kadv.png",
        "upper left",
    ),
    (
        "mean_H_res",
        "Mean weighted residual entropy H_res (bits)",
        "CNVS Test 9a/9b/9c — Residual Entropy by Graph Topology",
        "test9_weighted_residual_entropy.png",
        "upper right",
    ),
]

for spec in PLOT_SPECS:
    _plot_metric_vs_q(results, *spec)

print(results.head())
print("Saved files:")
for saved_path in (
    "cnvs_test9_dependent_graphs_results.csv",
    "cnvs_test9_dependent_graphs_results.xlsx",
    *(spec[3] for spec in PLOT_SPECS),
):
    print(saved_path)